# ======================================================================
# Copyright CERFACS (June 2019)
# Contributor: Adrien Suau (adrien.suau@cerfacs.fr)
#
# This software is governed by the CeCILL-B license under French law and
# abiding  by the  rules of  distribution of free software. You can use,
# modify  and/or  redistribute  the  software  under  the  terms  of the
# CeCILL-B license as circulated by CEA, CNRS and INRIA at the following
# URL "http://www.cecill.info".
#
# As a counterpart to the access to  the source code and rights to copy,
# modify and  redistribute granted  by the  license, users  are provided
# only with a limited warranty and  the software's author, the holder of
# the economic rights,  and the  successive licensors  have only limited
# liability.
#
# In this respect, the user's attention is drawn to the risks associated
# with loading,  using, modifying and/or  developing or reproducing  the
# software by the user in light of its specific status of free software,
# that  may mean  that it  is complicated  to manipulate,  and that also
# therefore  means that  it is reserved for  developers and  experienced
# professionals having in-depth  computer knowledge. Users are therefore
# encouraged  to load and  test  the software's  suitability as  regards
# their  requirements  in  conditions  enabling  the  security  of their
# systems  and/or  data to be  ensured and,  more generally,  to use and
# operate it in the same conditions as regards security.
#
# The fact that you  are presently reading this  means that you have had
# knowledge of the CeCILL-B license and that you accept its terms.
# ======================================================================

from bisect import insort as insert_sorted

from qat.lang.AQASM.gates import X, CCNOT

from qaths.routines.comparator import compare
from qaths.utils.NamedRoutine import NamedRoutine


class Segment:
    """Represent a half-open segment ``[start, stop)`` with a constant value on it.

    Segments are ordered by position: a segment is lower than another one if it
    ends before (or exactly where) the other one starts. Two segments that overlap
    without being equal cannot be ordered; comparing them raises a RuntimeError.
    """

    def __init__(self, start: int, stop: int, value):
        self.start = start
        self.stop = stop
        self.value = value
        assert self.start < self.stop, (
            "A segment should not be empty and should "
            "always have a starting point strictly inferior to its stop point."
        )

    def __repr__(self) -> str:
        return "Segment(start={!r}, stop={!r}, value={!r})".format(
            self.start, self.stop, self.value
        )

    def __cmp__(self, other: "Segment") -> int:
        """Three-way comparison of two segments by position.

        Python 3 never calls ``__cmp__`` implicitly. It is kept for backward
        compatibility and is used to implement the rich comparison operators.

        :return: -1 if ``self`` lies entirely before ``other``, 1 if it lies
            entirely after ``other``, 0 if both segments are equal.
        :raises RuntimeError: if the segments overlap without being equal.
        """
        # Segments are half-open ([start, stop)), so "self.stop == other.start"
        # means the two segments are adjacent, not overlapping.
        if self.stop <= other.start:
            return -1
        if self.start >= other.stop:
            return 1
        if self == other:
            return 0
        raise RuntimeError("Non comparable segments: {} VS {}".format(self, other))

    def __lt__(self, other: "Segment"):
        if not isinstance(other, Segment):
            return NotImplemented
        return self.__cmp__(other) < 0

    def __le__(self, other: "Segment"):
        if not isinstance(other, Segment):
            return NotImplemented
        return self.__cmp__(other) <= 0

    def __gt__(self, other: "Segment"):
        if not isinstance(other, Segment):
            return NotImplemented
        return self.__cmp__(other) > 0

    def __ge__(self, other: "Segment"):
        if not isinstance(other, Segment):
            return NotImplemented
        return self.__cmp__(other) >= 0

    def __eq__(self, other: object):
        # Let Python fall back to its default handling for foreign types instead
        # of crashing on a missing "start" attribute.
        if not isinstance(other, Segment):
            return NotImplemented
        return (
            self.start == other.start
            and self.stop == other.stop
            and self.value == other.value
        )

    def encode(
        self,
        input_size: int,
        value_size: int,
        value_encoder,
        start_already_computed: bool = False,
        stop_should_be_reused: bool = False,
    ) -> "NamedRoutine":
        r"""Construct a QRoutine encoding the segment.

        Let `i = input_size`, `v = value_size`
        The returned routine needs qubits organised as follow: ::

            |        |x>        |        |y>        |        |a>        |
            |   .   .   .   .   |   .   .   .   .   |   .   .   .   .   |
            |   0          i-1  |   i         i+v-1 |  i+v        i+v+2 |

        :param input_size: Size of the :math:`\ket{x}` register used as index.
        :param value_size: Size of the register that will store the value on the
            segment. This size should be large enough to encode the value provided.
        :param value_encoder: A function that takes a value `val` and the number of
            qubits this value should be encoded on `size` and return a routine that
            initialise a quantum register of `size` qubits to the value `val`.
        :param start_already_computed: A boolean flag set to True if the comparison with
            the lower limit (start) of the segment has already been computed in a[0].
            If False, the comparison is computed. In both cases a[0] is uncomputed
            at the end of the routine (X then the inverse of the "start"
            comparison), so when True a[0] must hold exactly that state.
        :param stop_should_be_reused: A boolean flag set to True if the comparison with
            the upper limit (stop) of the segment should not be uncomputed at the end of
            the calculations and left in a[1]. If False, the comparison is uncomputed.
        :return: a routine encoding the segment.
        """
        x = list(range(input_size))
        y = list(range(input_size, input_size + value_size))
        low, up, ancilla = list(
            range(input_size + value_size, input_size + value_size + 3)
        )
        rout = NamedRoutine("Segment.encode", arity=input_size + value_size + 3)

        # |a[0]>  :=  |x>  >=  segment.start   (negation of "|x> < start")
        if not start_already_computed:
            rout.protected_apply(compare(input_size, self.start), x, ancilla, low)
            rout.protected_apply(X, low)
        # |a[1]>  :=  |x>  <  segment.stop
        rout.protected_apply(compare(input_size, self.stop), x, ancilla, up)
        # |a[2]>  :=  |a[0]>  &&  |a[1]>
        rout.protected_apply(CCNOT, low, up, ancilla)

        # Write the segment value into |y> only for indices inside the segment.
        rout.apply(value_encoder(self.value, value_size).ctrl(), ancilla, y)

        # Uncompute
        rout.protected_apply(CCNOT.dag(), low, up, ancilla)
        # Keep "|x> < stop" in a[1] only when the caller wants to reuse it
        # (e.g. as the start comparison of an adjacent next segment).
        if not stop_should_be_reused:
            rout.protected_apply(compare(input_size, self.stop).dag(), x, ancilla, up)
        # a[0] is always uncomputed, even when it was provided by the caller.
        rout.protected_apply(X.dag(), low)
        rout.protected_apply(compare(input_size, self.start).dag(), x, ancilla, low)

        return rout


class Segments:
    """Represents multiple segments, each of them having a constant value.

    The segments are kept sorted using the natural ordering of their elements.
    """

    def __init__(self, start, stop, segments=None, default_value=None):
        if segments is None:
            segments = []
        # start, stop and default_value are stored but not used by the methods
        # of this class.
        self._start = start
        self._stop = stop
        self._segments = sorted(segments)
        self._default_value = default_value

    def add_segment(self, segment: "Segment"):
        """Insert ``segment`` while keeping the segment list sorted."""
        insert_sorted(self._segments, segment)

    @property
    def segments(self):
        return self._segments

    def encode(self, input_size: int, value_size: int, value_encoder) -> "NamedRoutine":
        r"""Construct a QRoutine encoding the segments.

        Let `i = input_size`, `v = value_size`
        The returned routine needs qubits organised as follow: ::

            |        |x>        |        |y>        |        |a>        |
            |   .   .   .   .   |   .   .   .   .   |   .   .   .   .   |
            |   0          i-1  |   i         i+v-1 |  i+v        i+v+2 |

        :param input_size: Size of the :math:`\ket{x}` register used as index.
        :param value_size: Size of the register that will store the value on the
            segment. This size should be large enough to encode the value provided.
        :param value_encoder: A function that takes a value `val` and the number of
            qubits this value should be encoded on `size` and return a routine that
            initialise a quantum register of `size` qubits to the value `val`.
        :return: a routine encoding the segments. If there is no segment, the
            returned routine is empty (it does not modify any qubit).
        """
        x = list(range(input_size))
        y = list(range(input_size, input_size + value_size))
        low, up, ancilla = list(
            range(input_size + value_size, input_size + value_size + 3)
        )
        rout = NamedRoutine("Segments.encode", arity=input_size + value_size + 3)

        if not self._segments:
            # Nothing to encode.
            return rout

        # start_saved is True when the qubit currently named "up" holds
        # "|x> < current_segment.start", so that the start comparison does not
        # need to be recomputed:
        #  - for later segments, this is the stop comparison the previous
        #    iteration left in "up" because both segments are adjacent;
        #  - for the first segment, the ancilla qubits are initialised to |0>,
        #    and "|x> < 0" is false for every |x>, so a fresh |0> qubit already
        #    holds that comparison if and only if the first segment starts at 0.
        start_saved = self._segments[0].start == 0

        # Pair every segment with its successor (None for the last one).
        successors = self._segments[1:] + [None]
        for segment, next_segment in zip(self._segments, successors):
            if start_saved:
                # "up" contains "|x> < segment.start", the exact opposite of what
                # "low" must hold ("|x> >= segment.start"): switch the two
                # ancillas to keep the naming consistent, then negate.
                low, up = up, low
                rout.apply(X, low)
            # The stop comparison is kept for the next iteration only if there
            # is a next segment and it starts exactly where this one stops.
            save_next = next_segment is not None and segment.stop == next_segment.start
            rout.apply(
                segment.encode(
                    input_size,
                    value_size,
                    value_encoder,
                    start_already_computed=start_saved,
                    stop_should_be_reused=save_next,
                ),
                x,
                y,
                low,
                up,
                ancilla,
            )
            start_saved = save_next
        return rout
